import numpy as np
import matplotlib.pyplot as plt


def rss_combine(coil_images):
    """Combine per-coil images into a single root-sum-of-squares image.

    Args:
        coil_images: Array whose first axis indexes coils; the remaining
            axes are the image dimensions. Real or complex values are
            accepted.

    Returns:
        A float64 array with the shape of ``coil_images[0]`` holding
        ``sqrt(sum_c |x_c|**2)``. An empty coil axis yields zeros.
    """
    coil_images = np.asarray(coil_images)
    # Use |x|**2 rather than x**2 so complex inputs give a real magnitude.
    # Square in float64 so integer inputs cannot overflow.
    magnitude = np.abs(coil_images).astype(np.float64)
    return np.sqrt(np.sum(magnitude ** 2, axis=0))


def save_rss_slices(pred_path='pred/total_pred.npy', out_dir='pred'):
    """Save one grayscale PNG per (sample, slice) of the stored predictions.

    The array loaded from ``pred_path`` is indexed as
    ``[sample, coil, slice, row, col]``. The coil axis is collapsed with
    :func:`rss_combine`, and each image is written to
    ``{out_dir}/{sample}__{slice}.png`` with the axes hidden.

    Args:
        pred_path: Path of the ``.npy`` file to load.
        out_dir: Existing directory the PNG files are written into.
    """
    img_data = np.load(pred_path)
    for sample in range(img_data.shape[0]):
        for slice_idx in range(img_data.shape[2]):
            img = rss_combine(img_data[sample, :, slice_idx])

            fig, ax = plt.subplots()
            try:
                ax.imshow(img, cmap='gray')
                # Hide the axes *before* saving; otherwise they end up in the PNG.
                ax.axis('off')
                fig.savefig(f'{out_dir}/{sample}__{slice_idx}.png')
            finally:
                # Close every figure so memory does not grow per image.
                plt.close(fig)


if __name__ == '__main__':
    save_rss_slices()
